{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pprint as pp\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch.utils.data as data_utils\n",
    "import matplotlib.pylab as plt\n",
    "import sys, os\n",
    "sys.path.append(os.path.join(os.path.dirname(\"__file__\"), '..'))\n",
    "from pytorch_net.net import MLP, ConvNet, load_model_dict, train, get_Layer\n",
    "from pytorch_net.util import Early_Stopping, get_param_name_list, get_variable_name_list, standardize_symbolic_expression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apart from section 0, each section is self-contained and can be run independently"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Preparing dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare a toy regression dataset: y = x^2 with x drawn from a standard normal.\n",
    "# Seed the RNGs so that the data and the train/test split are identical after every \"Restart & Run All\".\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)  # Also seeds torch-side randomness used in later cells, e.g. DataLoader shuffling\n",
    "\n",
    "X = np.random.randn(1000, 1)\n",
    "y = X ** 2\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = SEED)\n",
    "\n",
    "# Convert the numpy splits into float torch Variables:\n",
    "X_train = Variable(torch.FloatTensor(X_train))\n",
    "y_train = Variable(torch.FloatTensor(y_train))\n",
    "X_test = Variable(torch.FloatTensor(X_test))\n",
    "y_test = Variable(torch.FloatTensor(y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Construct a simple MLP:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the network.\n",
    "# Each entry of struct_param is [number of neurons in the layer, layer_type, layer settings].\n",
    "input_size = 1\n",
    "struct_param = [\n",
    "    [2, \"Simple_Layer\", {}],\n",
    "    [400, \"Simple_Layer\", {\"activation\": \"relu\"}],\n",
    "    [1, \"Simple_Layer\", {\"activation\": \"linear\"}],\n",
    "]\n",
    "\n",
    "# Default activation, used when a layer in struct_param does not specify its own.\n",
    "# An activation specified in a layer overwrites this default.\n",
    "settings = {\"activation\": \"relu\"}\n",
    "\n",
    "net = MLP(\n",
    "    input_size = input_size,\n",
    "    struct_param = struct_param,\n",
    "    settings = settings,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the prediction of the net: a forward pass on the training inputs.\n",
    "net(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get intermediate activation of the net:\n",
    "net.inspect_operation(X_train,                   # Input\n",
    "                      operation_between = (0,2), # Operation_between selects the subgraph. \n",
    "                                                 # Here (0, 2) means that the inputs feeds into layer_0 (first layer)\n",
    "                                                 # and before the layer_2 (third layer).\n",
    "                     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model. net.model_dict contains all the information (structure and parameters) of the network.\n",
    "# A context manager guarantees the file is flushed and closed before it is read back below.\n",
    "with open(\"net.p\", \"wb\") as f:\n",
    "    pickle.dump(net.model_dict, f)\n",
    "\n",
    "# Load model. NOTE: only unpickle files you trust -- pickle.load can execute arbitrary code.\n",
    "with open(\"net.p\", \"rb\") as f:\n",
    "    net_loaded = load_model_dict(pickle.load(f))\n",
    "\n",
    "# Check that the loaded net and the original net are identical (their outputs should not differ):\n",
    "net_loaded(X_train) - net(X_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Using symbolic layers, and simplification:\n",
    "### 2.1 Constructing MLP consisting of symbolic layers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct an MLP made of two Symbolic_Layers directly from a model_dict.\n",
    "model_dict = {\n",
    "    \"type\": \"MLP\",\n",
    "    \"input_size\": 4,\n",
    "    \"struct_param\": [\n",
    "        [2, \"Symbolic_Layer\", {\"symbolic_expression\": \"[3 * x0 ** 2 + p0 * x1 * x2 + p1 * x3, 5 * x0 ** 2 + p2 * x1 + p3 * x3 * x2]\"}],\n",
    "        [1, \"Symbolic_Layer\", {\"symbolic_expression\": \"[3 * x0 ** 2 + p2 * x1]\"}],\n",
    "    ],\n",
    "    # The optional \"weights\" gives the initial values of the parameters, one dict per layer.\n",
    "    # If not set, will initialize with N(0, 1):\n",
    "    \"weights\": [\n",
    "        {\"p0\": -1.3, \"p1\": 1.0, \"p2\": 2.3, \"p3\": -0.4},\n",
    "        {\"p2\": -1.5},\n",
    "    ],\n",
    "}\n",
    "net = load_model_dict(model_dict)\n",
    "pp.pprint(net.model_dict)\n",
    "print(\"\\nOutput:\")\n",
    "net(torch.rand(100, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Simplification of an MLP from Simple_Layer to Symbolic_Layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a small MLP of Simple_Layers that will be simplified in the next cell.\n",
    "# Each entry of struct_param is [number of neurons in the layer, layer_type, layer settings].\n",
    "input_size = 1\n",
    "struct_param = [\n",
    "    [2, \"Simple_Layer\", {\"activation\": \"relu\"}],\n",
    "    [10, \"Simple_Layer\", {\"activation\": \"linear\"}],\n",
    "    [1, \"Simple_Layer\", {\"activation\": \"relu\"}],\n",
    "]\n",
    "\n",
    "net = MLP(\n",
    "    input_size = input_size,\n",
    "    struct_param = struct_param,\n",
    "    settings = {},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simplify the net. The mode is a list of consecutive simplification methods, choosing from:\n",
    "#   'collapse_layers': collapse multiple Simple_Layer with linear activation into a single Simple_Layer;\n",
    "#   'local': greedily try reducing the input dimension by removing input dimension from the beginning;\n",
    "#   'snap': greedily snap each float parameter into an integer or rational number. Set argument 'snap_mode' == 'integer' or 'rational';\n",
    "#   'pair_snap': greedily try whether the ratio of a pair of parameters is an integer or rational number (by setting snap_mode);\n",
    "#   'activation_snap': snap the activation;\n",
    "#   'to_symbolic': transform the Simple_Layer into Symbolic_Layer;\n",
    "#   'symbolic_simplification': collapse multiple layers of Symbolic_Layer into a single Symbolic_Layer;\n",
    "#   'ramping-L1': increase the L1 regularization for the parameters and train. When some parameter is below a threshold, snap it to 0.\n",
    "net.simplify(X = X_train, y = y_train, mode = ['to_symbolic'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. SuperNet:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 Use SuperNet Layer in your own module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for a single SuperNet_Layer with 1 input and 2 outputs.\n",
    "layer_dict = {\n",
    "    \"layer_type\": \"SuperNet_Layer\",\n",
    "    \"input_size\": 1,\n",
    "    \"output_size\": 2,\n",
    "    \"settings\": {\n",
    "        # Weight type. Choose subset of \"dense\", \"Toeplitz\", \"arithmetic-series-in\", \"arithmetic-series-out\", \"arithmetic-series-2D-in\", \"arithmetic-series-2D-out\"\n",
    "        \"W_available\": [\"dense\", \"Toeplitz\"],\n",
    "        # Bias type. Choose subsets of \"None\", \"constant\", \"arithmetic-series\", \"arithmetic-series-2D\"\n",
    "        # NOTE(review): \"dense\" is used below but is not in the list above -- confirm the supported bias types.\n",
    "        \"b_available\": [\"dense\", \"None\", \"arithmetic-series\", \"constant\"],\n",
    "        # Activation. Choose subset of \"linear\", \"relu\", \"leakyRelu\", \"softplus\", \"sigmoid\", \"tanh\", \"selu\", \"elu\", \"softmax\"\n",
    "        \"A_available\": [\"linear\", \"relu\"],\n",
    "    },\n",
    "}\n",
    "SuperNet_Layer = get_Layer(**layer_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Two-layer module: the SuperNet_Layer followed by a Linear(2, 4) layer.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Reuses the module-level SuperNet_Layer object, so every Net instance shares that layer.\n",
    "        self.layer_0 = SuperNet_Layer\n",
    "        self.layer_1 = nn.Linear(2, 4)\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Feed X through layer_0, then layer_1, and return the result.\"\"\"\n",
    "        hidden = self.layer_0(X)\n",
    "        return self.layer_1(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the custom module and run a forward pass on the training inputs.\n",
    "net = Net()\n",
    "net(X_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Construct an MLP containing SuperNet layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry of struct_param is [number of neurons in the layer, layer_type, layer settings].\n",
    "input_size = 1\n",
    "struct_param = [\n",
    "    [6, \"Simple_Layer\", {}],\n",
    "    [4, \"SuperNet_Layer\", {\n",
    "        # Choose from \"linear\", \"relu\", \"leakyRelu\", \"softplus\", \"sigmoid\", \"tanh\", \"selu\", \"elu\", \"softmax\"\n",
    "        \"activation\": \"relu\",\n",
    "        \"W_available\": [\"dense\", \"Toeplitz\"],\n",
    "        \"b_available\": [\"dense\", \"None\", \"arithmetic-series\", \"constant\"],\n",
    "        \"A_available\": [\"linear\", \"relu\"],\n",
    "    }],\n",
    "    [1, \"Simple_Layer\", {\"activation\": \"linear\"}],\n",
    "]\n",
    "\n",
    "# Default activation, used when a layer in struct_param does not specify its own.\n",
    "# An activation specified in a layer overwrites this default.\n",
    "settings = {\"activation\": \"relu\"}\n",
    "\n",
    "net = MLP(\n",
    "    input_size = input_size,\n",
    "    struct_param = struct_param,\n",
    "    settings = settings,\n",
    ")\n",
    "net(X_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Training using explicit commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings:\n",
    "batch_size = 128\n",
    "epochs = 500\n",
    "\n",
    "# Prepare the training set batches.\n",
    "# The data loader must be fed torch Tensors, not Variables, hence the .data on X_train and y_train.\n",
    "dataset_train = data_utils.TensorDataset(X_train.data, y_train.data)\n",
    "train_loader = data_utils.DataLoader(\n",
    "    dataset_train,\n",
    "    batch_size = batch_size,\n",
    "    shuffle = True,\n",
    ")\n",
    "\n",
    "# Optimizer: Adam over all the parameters of the net.\n",
    "optimizer = optim.Adam(net.parameters(), lr = 1e-3)\n",
    "# Loss function: mean squared error.\n",
    "criterion = nn.MSELoss()\n",
    "# Early stopping: if the validation loss does not go down after \"patience\" number of epochs, stop early.\n",
    "early_stopping = Early_Stopping(patience = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_stop = False\n",
    "for epoch in range(epochs):\n",
    "    for batch_id, (X_batch, y_batch) in enumerate(train_loader):\n",
    "        # Every learning step must contain the following 5 steps:\n",
    "        optimizer.zero_grad()   # Zero out the gradient buffer\n",
    "        pred = net(Variable(X_batch))   # Obtain network's prediction\n",
    "        loss_train = criterion(pred, Variable(y_batch))  # Calculate the loss\n",
    "        loss_train.backward()    # Perform backward step on the loss to calculate the gradient for each parameter\n",
    "        optimizer.step()         # Use the optimizer to perform one step of parameter update\n",
    "\n",
    "    # Validation at the end of each epoch. No gradients are needed here, so run the forward pass\n",
    "    # under torch.no_grad() to avoid building an autograd graph for the whole test set:\n",
    "    with torch.no_grad():\n",
    "        loss_test = criterion(net(X_test), y_test)\n",
    "    to_stop = early_stopping.monitor(loss_test.item())\n",
    "    # batch_id and loss_train refer to the last batch of this epoch:\n",
    "    print(\"epoch {0} \\tbatch {1} \\tloss_train: {2:.6f}\\tloss_test: {3:.6f}\".format(epoch, batch_id, loss_train.item(), loss_test.item()))\n",
    "    if to_stop:\n",
    "        print(\"Early stopping at epoch {0}\".format(epoch))\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model. A context manager guarantees the file is flushed and closed before it is read back below.\n",
    "with open(\"net.p\", \"wb\") as f:\n",
    "    pickle.dump(net.model_dict, f)\n",
    "\n",
    "# Load model. NOTE: only unpickle files you trust -- pickle.load can execute arbitrary code.\n",
    "with open(\"net.p\", \"rb\") as f:\n",
    "    net_loaded = load_model_dict(pickle.load(f))\n",
    "\n",
    "# Check that the loaded net and the original net are identical (their outputs should not differ):\n",
    "net_loaded(X_train) - net(X_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Advanced example: training MNIST using given train() function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets\n",
    "import torch.utils.data as data_utils\n",
    "# NOTE: this train_test_split replaces the sklearn one imported in the first cell for the rest of the session.\n",
    "from pytorch_net.util import train_test_split, normalize_tensor, to_Variable\n",
    "is_cuda = torch.cuda.is_available()\n",
    "\n",
    "# Four identical Conv2d -> BatchNorm2d (relu) -> MaxPool2d blocks, followed by a 10-unit Simple_Layer.\n",
    "# The inner tuple is rebuilt on every repetition, so each block gets its own list and dict objects.\n",
    "struct_param_conv = [\n",
    "    layer_param\n",
    "    for _ in range(4)\n",
    "    for layer_param in (\n",
    "        [64, \"Conv2d\", {\"kernel_size\": 3, \"padding\": 1}],\n",
    "        [64, \"BatchNorm2d\", {\"activation\": \"relu\"}],\n",
    "        [None, \"MaxPool2d\", {\"kernel_size\": 2}],\n",
    "    )\n",
    "]\n",
    "struct_param_conv.append([10, \"Simple_Layer\", {\"layer_input_size\": 64}])\n",
    "\n",
    "model = ConvNet(\n",
    "    input_channels = 1,\n",
    "    struct_param = struct_param_conv,\n",
    "    is_cuda = is_cuda,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This cell relies on the names imported in the previous cell (datasets, to_Variable, normalize_tensor,\n",
    "# the pytorch_net.util train_test_split) and on is_cuda.\n",
    "dataset_name = \"MNIST\"\n",
    "batch_size = 128\n",
    "\n",
    "# Look up the dataset class by name and load it with download enabled:\n",
    "dataset_raw = getattr(datasets, dataset_name)(\n",
    "    'datasets/{0}'.format(dataset_name),\n",
    "    download = True,\n",
    ")\n",
    "\n",
    "# Insert a singleton dimension at index 1 of the images (presumably the channel axis) and cast them to float:\n",
    "X, y = to_Variable(\n",
    "    dataset_raw.train_data.unsqueeze(1).float(),\n",
    "    dataset_raw.train_labels,\n",
    "    is_cuda = is_cuda,\n",
    ")\n",
    "X = normalize_tensor(X, new_range = (0, 1))\n",
    "\n",
    "# Split into training and testing:\n",
    "(X_train, y_train), (X_test, y_test) = train_test_split(X, y, test_size = 0.2)\n",
    "\n",
    "# Initialize the DataLoader over the training split:\n",
    "dataset_train = data_utils.TensorDataset(X_train, y_train)\n",
    "train_loader = data_utils.DataLoader(\n",
    "    dataset_train,\n",
    "    batch_size = batch_size,\n",
    "    shuffle = True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the ConvNet with the train() function from pytorch_net.net.\n",
    "lr = 1e-3\n",
    "reg_dict = {\"weight\": 1e-6, \"bias\": 1e-6}\n",
    "\n",
    "loss_original, loss_value, data_record = train(\n",
    "    model,\n",
    "    train_loader = train_loader,\n",
    "    validation_data = (X_test, y_test),\n",
    "    criterion = nn.CrossEntropyLoss(),\n",
    "    lr = lr,\n",
    "    reg_dict = reg_dict,\n",
    "    epochs = 1000,\n",
    "    isplot = True,\n",
    "    patience = 40,\n",
    "    scheduler_patience = 40,\n",
    "    inspect_items = [\"accuracy\"],\n",
    "    record_keys = [\"accuracy\"],\n",
    "    inspect_interval = 1,\n",
    "    inspect_items_interval = 1,\n",
    "    inspect_loss_precision = 4,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Example callback code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of a plotting callback. The plotting settings below are example values; adjust them to your task.\n",
    "from functools import partial\n",
    "import matplotlib.pylab as plt\n",
    "from pytorch_net.util import to_np_array  # NOTE(review): assumed to live in pytorch_net.util -- confirm\n",
    "\n",
    "# These names were previously used without being defined anywhere in the notebook (NameError on a fresh kernel):\n",
    "isplot = True   # Whether to plot at all\n",
    "dim = 2         # 3 selects a 3D scatter plot, anything else a 2D scatter plot\n",
    "pause = 0.01    # Argument passed to plt.pause() after each redraw\n",
    "\n",
    "if isplot:\n",
    "    fig = plt.figure()\n",
    "    if dim == 3:\n",
    "        ax = fig.add_subplot(111, projection='3d')\n",
    "    else:\n",
    "        fig.add_axes([0, 0, 1, 1])\n",
    "        ax = fig.axes[0]\n",
    "else:\n",
    "    ax = None\n",
    "\n",
    "\n",
    "def visualize(model, X, y, iteration, loss, ax, dim):\n",
    "    \"\"\"Scatter-plot the reference points y against the model output model.transform(X).\n",
    "\n",
    "    Does nothing when the global isplot is False.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing a transform(X) method.\n",
    "        X: input passed to model.transform.\n",
    "        y: reference points; columns 0 and 1 (and 2 if dim == 3) are plotted.\n",
    "        iteration: integer iteration number shown in the annotation.\n",
    "        loss: error value shown in the annotation.\n",
    "        ax: matplotlib axes to draw on.\n",
    "        dim: 3 for a 3D scatter plot, otherwise a 2D scatter plot.\n",
    "    \"\"\"\n",
    "    if not isplot:\n",
    "        return\n",
    "    # Clear and annotate the axes that was passed in, not whichever axes happens to be current:\n",
    "    ax.cla()\n",
    "    y_pred = model.transform(X)\n",
    "    y_pred, y = to_np_array(y_pred, y)\n",
    "    annotation = 'Iteration: {:d}\\nError: {:06.4f}'.format(iteration, loss)\n",
    "    text_kwargs = dict(horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')\n",
    "    if dim == 3:\n",
    "        ax.scatter(y[:,0],  y[:,1], y[:,2], color='red', label='ref', s = 1)\n",
    "        ax.scatter(y_pred[:,0],  y_pred[:,1], y_pred[:,2], color='blue', label='data', s = 1)\n",
    "        ax.text2D(0.87, 0.92, annotation, **text_kwargs)\n",
    "    else:\n",
    "        ax.scatter(y[:,0], y[:,1], color='red', label='ref')\n",
    "        ax.scatter(y_pred[:,0], y_pred[:,1], color='blue', label='data')\n",
    "        ax.text(0.87, 0.92, annotation, **text_kwargs)\n",
    "    ax.legend(loc='upper left', fontsize='x-large')\n",
    "    plt.draw()\n",
    "    plt.pause(pause)\n",
    "\n",
    "\n",
    "# Bind the axes and dimensionality so that the callback only takes (model, X, y, iteration, loss):\n",
    "callback = partial(visualize, ax = ax, dim = dim)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
